{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GD5Omr5EfWgb"
   },
   "source": [
    "# Date Generator\n",
    "\n",
    "generate synthetic data when given scheme, business problem description, model, number of records, file name, file type, and environment\n",
    "\n",
    "# Available models\n",
    "  Model API:\n",
    "\n",
    "    1. gpt-4o-mini\n",
    "    2. claude-3-haiku-20240307\n",
    "    3. gemini-2.0-flash\n",
    "    4. deepseek-chat\"\n",
    "\n",
    "  HuggingFace API:\n",
    "\n",
    "    5. meta-llama/Meta-Llama-3.1-8B-Instruct\n",
    "\n",
    "\n",
    "# Available environment\n",
    "\n",
    "Colab: set up HF token and API keys in Colab secret section\n",
    "\n",
    "Local: set up HF token and API keys in .env file\n",
    "\n",
    "\n",
    "\n",
    "### *** This project is developed based on the idea of 'week3/community-contributuins/Week3-Dataset_Generator-DP'. Really appreciate it! Then, the project is improved to run both on Colab or locally, and integrate HuggingFace API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4FiCnE0MmU56"
   },
   "outputs": [],
   "source": [
    "!pip install -q --upgrade torch==2.5.1+cu124 torchvision==0.20.1+cu124 torchaudio==2.5.1+cu124 --index-url https://download.pytorch.org/whl/cu124\n",
    "!pip install -q requests bitsandbytes==0.46.0 transformers==4.48.3 accelerate==1.3.0\n",
    "!pip install anthropic dotenv pyarrow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JeyKw5guoH3r"
   },
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import requests\n",
    "from IPython.display import Markdown, display, update_display\n",
    "from openai import OpenAI\n",
    "from huggingface_hub import login\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig\n",
    "from bs4 import BeautifulSoup\n",
    "from typing import List\n",
    "import google.generativeai\n",
    "import anthropic\n",
    "from itertools import chain\n",
    "from dotenv import load_dotenv\n",
    "import gradio as gr\n",
    "import json\n",
    "import pandas as pd\n",
    "import random\n",
    "import re\n",
    "import subprocess\n",
    "import pyarrow as pa\n",
    "import torch\n",
    "import gc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "7UyjFdRZoIAS"
   },
   "outputs": [],
   "source": [
    "# --- Schema Definition ---\n",
    "SCHEMA = [\n",
    "    (\"Name\", \"TEXT\", '\"Northern Cafe\"'),\n",
    "    (\"Location\", \"TEXT\", '\"2904 S Figueroa St, Los Angeles, CA 90007\"'),\n",
    "    (\"Type\", \"TEXT\", 'One of [\"Chinese\",\"Mexico\",\"French\",\"Korean\",\"Italy\"] or other potential types'),\n",
    "    (\"Average Price\", \"TEXT\", '\"$30\", or \"--\" if unkown'),\n",
    "    (\"History/Age\", \"INT\", 'integer age of resturant, e.g., 7'),\n",
    "    (\"Menu\", \"Array\", '[\"Beef Noodle\", \"Fried Rice\", \"Dumpling\", ...]'),\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "jXcTQATLoICV"
   },
   "outputs": [],
   "source": [
    "# Default schema text for the textbox\n",
    "DEFAULT_SCHEMA_TEXT = \"\\n\".join([f\"{i+1}. {col[0]} ({col[1]}) Example: {col[2]}\" for i, col in enumerate(SCHEMA)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4Irf5JV3oIEe"
   },
   "outputs": [],
   "source": [
    "# Available models\n",
    "MODELS = [\n",
    "    \"gpt-4o-mini\",\n",
    "    \"claude-3-haiku-20240307\",\n",
    "    \"gemini-2.0-flash\",\n",
    "    \"deepseek-chat\",\n",
    "    \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JJ6r2SH9oIGf"
   },
   "outputs": [],
   "source": [
    "# Available file formats\n",
    "FILE_FORMATS = [\".csv\", \".tsv\", \".jsonl\", \".parquet\", \".arrow\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "B98j45E3vq5g"
   },
   "outputs": [],
   "source": [
    "system_prompt = \"\"\"You are a helpful assistant whose main purpose is to generate datasets for a given business problem based on given schema.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "lsX16cWfwf6x"
   },
   "outputs": [],
   "source": [
    "def get_env_info(env):\n",
    "  try:\n",
    "    global hf_token, openai_api_key, anthropic_api_key, google_api_key, deepseek_api_key\n",
    "    if env == \"Colab\":\n",
    "      # Colab environment\n",
    "      from google.colab import drive\n",
    "      from google.colab import userdata\n",
    "      hf_token = userdata.get('HF_TOKEN')\n",
    "      openai_api_key = userdata.get('OPENAI_API_KEY')\n",
    "      anthropic_api_key = userdata.get('ANTHROPIC_API_KEY')\n",
    "      google_api_key = userdata.get('GOOGLE_API_KEY')\n",
    "      deepseek_api_key = userdata.get('DEEPSEEK_API_KEY')\n",
    "    elif env == \"Local\":\n",
    "      # Local environment\n",
    "      load_dotenv(override=True)\n",
    "      hf_token = os.getenv('HF_TOKEN')\n",
    "      openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "      anthropic_api_key = os.getenv('ANTHROPIC_API_KEY')\n",
    "      google_api_key = os.getenv('GOOGLE_API_KEY')\n",
    "      deepseek_api_key = os.getenv('DEEPSEEK_API_KEY')\n",
    "  except Exception as e:\n",
    "      raise Exception(f\"Please check your environment: {str(e)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2gLUFAwGv29Q"
   },
   "outputs": [],
   "source": [
    "def get_prompt(schema_text, business_problem, nr_records):\n",
    "    prompt = f\"\"\"\n",
    "      The problem is: {business_problem}\n",
    "\n",
    "      Generate {nr_records} rows data in JSONL format, each line a JSON object with the following fields:\n",
    "\n",
    "      {schema_text}\n",
    "\n",
    "      Do NOT repeat column values from one row to another.\n",
    "\n",
    "      Only output valid JSONL.\n",
    "      \"\"\"\n",
    "    return prompt.strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YZe1FVH8wf84"
   },
   "outputs": [],
   "source": [
    "# --- LLM Interface ---\n",
    "def query(user_prompt, model):\n",
    "    try:\n",
    "        if \"gpt\" in model.lower():\n",
    "            client = OpenAI(api_key=openai_api_key)\n",
    "            messages = [\n",
    "                {\"role\": \"system\", \"content\": system_prompt},\n",
    "                {\"role\": \"user\", \"content\": user_prompt}\n",
    "              ]\n",
    "            response = client.chat.completions.create(\n",
    "                model=model,\n",
    "                messages=messages,\n",
    "                temperature=0.7\n",
    "            )\n",
    "            content = response.choices[0].message.content\n",
    "\n",
    "        elif \"claude\" in model.lower():\n",
    "            client = anthropic.Anthropic(api_key=anthropic_api_key)\n",
    "            response = client.messages.create(\n",
    "                model=model,\n",
    "                messages=[{\"role\": \"user\", \"content\": user_prompt}],\n",
    "                max_tokens=4000,\n",
    "                temperature=0.7,\n",
    "                system=system_prompt\n",
    "            )\n",
    "            content = response.content[0].text\n",
    "        elif \"gemini\" in model.lower():\n",
    "            client = OpenAI(\n",
    "                api_key=google_api_key,\n",
    "                base_url=\"https://generativelanguage.googleapis.com/v1beta/openai/\"\n",
    "            )\n",
    "            messages = [\n",
    "                {\"role\": \"system\", \"content\": system_prompt},\n",
    "                {\"role\": \"user\", \"content\": user_prompt}\n",
    "              ]\n",
    "            response = client.chat.completions.create(\n",
    "                model=model,\n",
    "                messages=messages,\n",
    "                temperature=0.7\n",
    "            )\n",
    "            content = response.choices[0].message.content\n",
    "\n",
    "        elif \"deepseek\" in model.lower():\n",
    "            client = OpenAI(\n",
    "                api_key=deepseek_api_key,\n",
    "                base_url=\"https://api.deepseek.com\"\n",
    "            )\n",
    "            messages = [\n",
    "                {\"role\": \"system\", \"content\": system_prompt},\n",
    "                {\"role\": \"user\", \"content\": user_prompt}\n",
    "              ]\n",
    "            response = client.chat.completions.create(\n",
    "                model=model,\n",
    "                messages=messages,\n",
    "                temperature=0.7\n",
    "            )\n",
    "            content = response.choices[0].message.content\n",
    "\n",
    "        elif \"llama\" in model.lower():\n",
    "            global tokenizer, inputs, llama_model, outputs\n",
    "            messages = [\n",
    "                  {\"role\": \"system\", \"content\": system_prompt},\n",
    "                  {\"role\": \"user\", \"content\": user_prompt}\n",
    "                ]\n",
    "\n",
    "            login(hf_token, add_to_git_credential=True)\n",
    "            quant_config = BitsAndBytesConfig(\n",
    "                load_in_4bit=True,\n",
    "                bnb_4bit_use_double_quant=True,\n",
    "                bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "                bnb_4bit_quant_type=\"nf4\"\n",
    "            )\n",
    "\n",
    "            tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True)\n",
    "            tokenizer.pad_token = tokenizer.eos_token\n",
    "            inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\").to(\"cuda\")\n",
    "            if llama_model == None:\n",
    "                llama_model = AutoModelForCausalLM.from_pretrained(model, device_map=\"auto\", quantization_config=quant_config)\n",
    "            outputs = llama_model.generate(inputs, max_new_tokens=4000)\n",
    "\n",
    "            _, _, after = tokenizer.decode(outputs[0]).partition(\"assistant<|end_header_id|>\")\n",
    "            content = after.strip()\n",
    "        else:\n",
    "            raise ValueError(f\"Unsupported model. Use one of {MODELS}\")\n",
    "\n",
    "        # Parse JSONL output\n",
    "        lines = [line.strip() for line in content.strip().splitlines() if line.strip().startswith(\"{\")]\n",
    "        return [json.loads(line) for line in lines]\n",
    "\n",
    "    except Exception as e:\n",
    "        raise Exception(f\"Model query failed: {str(e)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "4WUj-XqM5IYT"
   },
   "outputs": [],
   "source": [
    "# --- Output Formatter ---\n",
    "def save_dataset(records, file_format, filename):\n",
    "    df = pd.DataFrame(records)\n",
    "    if file_format == \".csv\":\n",
    "        df.to_csv(filename, index=False)\n",
    "    elif file_format == \".tsv\":\n",
    "        df.to_csv(filename, sep=\"\\t\", index=False)\n",
    "    elif file_format == \".jsonl\":\n",
    "        with open(filename, \"w\") as f:\n",
    "            for record in records:\n",
    "                f.write(json.dumps(record) + \"\\n\")\n",
    "    elif file_format == \".parquet\":\n",
    "        df.to_parquet(filename, engine=\"pyarrow\", index=False)\n",
    "    elif file_format == \".arrow\":\n",
    "        table = pa.Table.from_pandas(df)\n",
    "        with pa.OSFile(filename, \"wb\") as sink:\n",
    "            with pa.ipc.new_file(sink, table.schema) as writer:\n",
    "                writer.write(table)\n",
    "    else:\n",
    "        raise ValueError(\"Unsupported file format\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WenbNqrpwf-_"
   },
   "outputs": [],
   "source": [
    "# --- Main Generation Function ---\n",
    "def generate_dataset(schema_text, business_problem, model, nr_records, file_format, save_as, env):\n",
    "    try:\n",
    "        # Validation\n",
    "        if nr_records <= 10:\n",
    "            return \"❌ Error: Number of records must be greater than 10.\", None\n",
    "        if nr_records > 1000:\n",
    "            return \"❌ Error: Number of records must be less than or equal to 1000.\", None\n",
    "\n",
    "        if file_format not in FILE_FORMATS:\n",
    "            return \"❌ Error: Invalid file format.\", None\n",
    "\n",
    "        if not (save_as or save_as.strip() == \"\"):\n",
    "            save_as = f\"default{file_format}\"\n",
    "        elif not save_as.endswith(file_format):\n",
    "            save_as = save_as + file_format\n",
    "\n",
    "        # Load env\n",
    "        get_env_info(env)\n",
    "\n",
    "        # Generate prompt\n",
    "        user_prompt = get_prompt(schema_text, business_problem, nr_records)\n",
    "\n",
    "        # Query model\n",
    "        records = query(user_prompt, model)\n",
    "\n",
    "        if not records:\n",
    "            return \"❌ Error: No valid records generated from the model.\", None\n",
    "\n",
    "        # Save dataset\n",
    "        save_dataset(records, file_format, save_as)\n",
    "\n",
    "        # Create preview\n",
    "        df = pd.DataFrame(records)\n",
    "        preview = df.head(10)  # Show first 10 rows\n",
    "\n",
    "        success_message = f\"✅ Generated {len(records)} records successfully!\\n📁 Saved to: {save_as}\\n📊 \"\n",
    "\n",
    "        return success_message, preview\n",
    "\n",
    "    except Exception as e:\n",
    "        return f\"❌ Error: {str(e)}\", None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pHiP8ky8wgEb"
   },
   "outputs": [],
   "source": [
    "# --- Gradio Interface ---\n",
    "\n",
    "with gr.Blocks(title=\"Dataset Generator\", theme=gr.themes.Citrus()) as interface:\n",
    "    hf_token = None\n",
    "    openai_api_key = None\n",
    "    anthropic_api_key = None\n",
    "    google_api_key = None\n",
    "    deepseek_api_key = None\n",
    "    tokenizer = None\n",
    "    inputs = None\n",
    "    llama_model = None\n",
    "    outputs = None\n",
    "\n",
    "    gr.Markdown(\"# Dataset Generator\")\n",
    "    gr.Markdown(\"Generate synthetic datasets using AI models\")\n",
    "\n",
    "    with gr.Row():\n",
    "        with gr.Column(scale=2):\n",
    "            schema_input = gr.Textbox(\n",
    "                label=\"Schema\",\n",
    "                value=DEFAULT_SCHEMA_TEXT,\n",
    "                lines=15,\n",
    "                placeholder=\"Define your dataset schema here... Please follow this format: Field_Name, Field_Type, Field Example\"\n",
    "            )\n",
    "\n",
    "            business_problem_input = gr.Textbox(\n",
    "                label=\"Business Problem\",\n",
    "                value=\"I want to generate restuant records\",\n",
    "                lines=1,\n",
    "                placeholder=\"Enter business problem desciption for the model...\"\n",
    "            )\n",
    "\n",
    "            with gr.Row():\n",
    "                model_dropdown = gr.Dropdown(\n",
    "                    label=\"Model\",\n",
    "                    choices=MODELS,\n",
    "                    value=MODELS[0],\n",
    "                    interactive=True\n",
    "                )\n",
    "\n",
    "                nr_records_input = gr.Number(\n",
    "                    label=\"Number of records\",\n",
    "                    value=27,\n",
    "                    minimum=11,\n",
    "                    maximum=1000,\n",
    "                    step=1\n",
    "                )\n",
    "\n",
    "            with gr.Row():\n",
    "                save_as_input = gr.Textbox(\n",
    "                      label=\"Save as\",\n",
    "                      value=\"restaurant_dataset\",\n",
    "                      placeholder=\"Enter filename (extension will be added automatically)\"\n",
    "                  )\n",
    "\n",
    "                file_format_dropdown = gr.Dropdown(\n",
    "                    label=\"File format\",\n",
    "                    choices=FILE_FORMATS,\n",
    "                    value=FILE_FORMATS[0],\n",
    "                    interactive=True\n",
    "                )\n",
    "\n",
    "                env_dropdown = gr.Dropdown(\n",
    "                    label=\"Environment\",\n",
    "                    choices=[\"Colab\", \"Local\"],\n",
    "                    value=\"Colab\",\n",
    "                    interactive=True\n",
    "                )\n",
    "\n",
    "\n",
    "\n",
    "            generate_btn = gr.Button(\"🚀 Generate\", variant=\"secondary\", size=\"lg\")\n",
    "\n",
    "        with gr.Column(scale=1):\n",
    "            output_status = gr.Textbox(\n",
    "                label=\"Status\",\n",
    "                lines=4,\n",
    "                interactive=False\n",
    "            )\n",
    "\n",
    "            output_preview = gr.Dataframe(\n",
    "                label=\"Preview (First 10 rows)\",\n",
    "                interactive=False,\n",
    "                wrap=True\n",
    "            )\n",
    "\n",
    "    # Connect the generate button\n",
    "    generate_btn.click(\n",
    "        fn=generate_dataset,\n",
    "        inputs=[\n",
    "            schema_input,\n",
    "            business_problem_input,\n",
    "            model_dropdown,\n",
    "            nr_records_input,\n",
    "            file_format_dropdown,\n",
    "            save_as_input,\n",
    "            env_dropdown\n",
    "        ],\n",
    "        outputs=[output_status, output_preview]\n",
    "    )\n",
    "\n",
    "    gr.Markdown(\"\"\"\n",
    "    ### 📝 Instructions:\n",
    "    1. **Schema**: Define the structure of your dataset (pre-filled with restaurant schema)\n",
    "    2. **Business problem**: User prompt to guide the AI model\n",
    "    3. **Model**: Choose between GPT, Claude, Gemini, DeepSeek or Llama models\n",
    "    4. **Number of records**: Number of records to generate (minimum 11)\n",
    "    5. **File format**: Choose output format (.csv, .tsv, .jsonl, .parquet, .arrow)\n",
    "    6. **Save as**: Filename (extension added automatically)\n",
    "    7. Click **Generate** to create your dataset\n",
    "\n",
    "    ### 🔧 Requirements:\n",
    "    - For local mode, set up HF token and API keys in `.env` file (`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `GOOGLE_API_KEY`, `DEEPSEEK_API_KEY`, `HF_TOKEN`)\n",
    "    - For colab mode, set up HF token and API keys in Colab secret section (`OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, `GOOGLE_API_KEY`, `DEEPSEEK_API_KEY`, `HF_TOKEN`)\n",
    "    \"\"\")\n",
    "\n",
    "interface.launch(debug=True)\n",
    "\n",
    "del tokenizer, inputs, llama_model, outputs\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
